import json
import random
from Logger import logger

def load_jsonl(path, encoding="utf-8"):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is decoded with ``json.loads``. Whitespace-only lines
    (for example an extra empty line at the end of the file) are skipped
    instead of raising ``json.JSONDecodeError``.

    Args:
        path: Path of the ``.jsonl`` file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale encoding.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding=encoding) as f:
        return [json.loads(line) for line in f if line.strip()]

def k_shot_split(data, k_shots=10, seed=None):
    """Split labelled records into a k-shot train set and a test set.

    The records are shuffled and grouped by their ``'label'`` key. For every
    label, the first ``k_shots`` records go to the train set and the rest go
    to the test set. A label with at most ``k_shots`` records puts all of
    them in the train set and none in the test set.

    Args:
        data: Iterable of dict-like records, each with a ``'label'`` key.
            It is not modified; a shuffled copy is used internally.
        k_shots: Number of training records to take per label. Must be >= 0.
        seed: Optional seed for a private ``random.Random`` instance, making
            the split reproducible without touching the global random state.
            When ``None`` (the default), the module-level ``random``
            generator is used.

    Returns:
        A ``(train_data, test_data)`` tuple of lists.

    Raises:
        ValueError: If ``k_shots`` is negative.
        KeyError: If a record has no ``'label'`` key.
    """
    if k_shots < 0:
        # A negative slice bound would silently mean "all but the last |k|".
        raise ValueError("k_shots must be non-negative, got {}".format(k_shots))

    rng = random if seed is None else random.Random(seed)
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    shuffled = list(data)
    rng.shuffle(shuffled)

    # Group records by label, keeping each label's post-shuffle order.
    label2data = {}
    for d in shuffled:
        label2data.setdefault(d['label'], []).append(d)

    logger.info("Label distribution: {}".format(json.dumps({label: len(label2data[label]) for label in label2data}, indent=4)))

    train_data = []
    test_data = []
    for label_data in label2data.values():
        train_data.extend(label_data[:k_shots])
        test_data.extend(label_data[k_shots:])
    return train_data, test_data
